# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

from __future__ import absolute_import, division, print_function, unicode_literals

import itertools
import math
import unittest

import thrift.util.randomizer as randomizer
from fuzz import ttypes
from thrift import Thrift


class TestRandomizer:
    """Mixin shared by the randomizer test cases in this module.

    It supplies the number of values each test draws and a factory for
    building the randomizer under test.
    """

    # How many values each test pulls from its randomizer.
    iterations = 1024

    def get_randomizer(self, ttypes, spec_args, constraints):
        """Build a randomizer for the given type, spec args and constraints.

        The randomizer comes from a fresh ``RandomizerState`` constructed
        with ``{"global_constraint": 100}``.
        """
        state_arg = {"global_constraint": 100}
        return randomizer.RandomizerState(state_arg).get_randomizer(
            ttypes, spec_args, constraints
        )


class TestBoolRandomizer(unittest.TestCase, TestRandomizer):
    """Exercise the BOOL randomizer under constraints that pin its output."""

    def _generated_bools(self, constraints):
        """Yield ``iterations`` values from a BOOL randomizer built with
        `constraints`."""
        gen = self.get_randomizer(Thrift.TType.BOOL, None, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def test_always_true(self):
        """With ``p_true`` set to 1.0 every value must be truthy."""
        for value in self._generated_bools({"p_true": 1.0}):
            self.assertTrue(value)

    def test_always_false(self):
        """With ``p_true`` set to 0.0 every value must be falsy."""
        for value in self._generated_bools({"p_true": 0.0}):
            self.assertFalse(value)

    def test_seeded(self):
        """A lone ``True`` seed with random and fuzz generation disabled
        must always come back truthy."""
        seeded = {"seeds": [True], "p_random": 0, "p_fuzz": 0}
        for value in self._generated_bools(seeded):
            self.assertTrue(value)

    def test_int_seeded(self):
        """Same as ``test_seeded`` but the seed is the integer ``1``."""
        seeded = {"seeds": [1], "p_random": 0, "p_fuzz": 0}
        for value in self._generated_bools(seeded):
            self.assertTrue(value)


class TestEnumRandomizer(unittest.TestCase, TestRandomizer):
    """Exercise the randomizer obtained for I32 with ``ttypes.Color`` as
    its spec args."""

    def _generated_colors(self, constraints):
        """Yield ``iterations`` values from a ``Color`` randomizer built
        with `constraints`."""
        gen = self.get_randomizer(Thrift.TType.I32, ttypes.Color, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def test_always_valid(self):
        """``p_invalid`` of 0 only yields values known to the enum."""
        for value in self._generated_colors({"p_invalid": 0}):
            self.assertIn(value, ttypes.Color._VALUES_TO_NAMES)

    def test_never_valid(self):
        """``p_invalid`` of 1 only yields values unknown to the enum."""
        for value in self._generated_colors({"p_invalid": 1}):
            self.assertNotIn(value, ttypes.Color._VALUES_TO_NAMES)

    def test_choices(self):
        """Generated values map back to one of the chosen enum names."""
        allowed = ["RED", "BLUE", "BLACK"]
        for value in self._generated_colors({"choices": allowed}):
            self.assertIn(ttypes.Color._VALUES_TO_NAMES[value], allowed)

    def test_seeded(self):
        """With random and fuzz generation disabled, values map back to
        one of the seed names."""
        seed_names = ["RED", "BLUE", "BLACK"]
        seeded = {"seeds": seed_names, "p_random": 0, "p_fuzz": 0}
        for value in self._generated_colors(seeded):
            self.assertIn(ttypes.Color._VALUES_TO_NAMES[value], seed_names)


class TestIntRandomizer(TestRandomizer):
    """Shared tests for the integer randomizers.

    The methods read ``ttype`` (the Thrift type under test) and ``n_bits``
    (its width in bits) from the concrete class.
    """

    @classmethod
    def _one_bit_flipped(cls, a, b):
        """Return True if `a` and `b` differ in at most one bit position."""
        changed_bits = a ^ b
        # Clearing the lowest set bit leaves zero exactly when no more than
        # one bit was set in the first place.
        return (changed_bits & (changed_bits - 1)) == 0

    @classmethod
    def _within_delta(cls, a, b, delta):
        """Return True if `a` and `b` are no more than `delta` apart."""
        return abs(a - b) <= delta

    @classmethod
    def _is_fuzzed_single_seed(cls, seed, fuzzed, delta):
        """Return True if `fuzzed` is `seed` with at most one bit flipped,
        or lies within `delta` of it."""
        if cls._one_bit_flipped(seed, fuzzed):
            return True
        return cls._within_delta(seed, fuzzed, delta)

    @classmethod
    def is_fuzzed(cls, seeds, fuzzed, delta=None):
        """Check whether `fuzzed` could have been produced by fuzzing any
        element of `seeds`.

        When `delta` is None, the ``fuzz_max_delta`` entry of the default
        constraints of a randomizer for ``cls.ttype`` is used.
        """
        if delta is None:
            default_gen = randomizer.RandomizerState().get_randomizer(
                cls.ttype, None, {}
            )
            delta = default_gen.default_constraints["fuzz_max_delta"]

        for seed in seeds:
            if cls._is_fuzzed_single_seed(seed, fuzzed, delta):
                return True
        return False

    @property
    def min(self):
        """Smallest value of a signed integer that is ``n_bits`` wide."""
        width = type(self).n_bits
        return -(1 << (width - 1))

    @property
    def max(self):
        """Largest value of a signed integer that is ``n_bits`` wide."""
        width = type(self).n_bits
        return (1 << (width - 1)) - 1

    def _generated_ints(self, constraints):
        """Yield ``iterations`` values from a randomizer for ``ttype``
        built with `constraints`."""
        gen = self.get_randomizer(type(self).ttype, None, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testInRange(self):
        """Unconstrained values stay inside ``[min, max]``."""
        lower, upper = self.min, self.max
        for value in self._generated_ints({}):
            self.assertGreaterEqual(value, lower)
            self.assertLessEqual(value, upper)

    def testConstant(self):
        """A single-element ``choices`` list always yields that element."""
        expected = 17
        for value in self._generated_ints({"choices": [expected]}):
            self.assertEqual(value, expected)

    def testChoices(self):
        """Values are always drawn from ``choices``."""
        allowed = [11, 17, 19]
        for value in self._generated_ints({"choices": allowed}):
            self.assertIn(value, allowed)

    def testRange(self):
        """Values respect the inclusive bounds given by ``range``."""
        bounds = [45, 55]
        for value in self._generated_ints({"range": bounds}):
            self.assertGreaterEqual(value, bounds[0])
            self.assertLessEqual(value, bounds[1])

    def testRangeChoicePrecedence(self):
        """When both are given, ``choices`` wins over ``range``."""
        bounds = [45, 55]
        allowed = [11, 17, 19]
        both = {"range": bounds, "choices": allowed}
        for value in self._generated_ints(both):
            self.assertIn(value, allowed)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seed_values = [11, 17, 19]
        seeded = {"seeds": seed_values, "p_random": 0, "p_fuzz": 0}
        for value in self._generated_ints(seeded):
            self.assertIn(value, seed_values)

    def testFuzzing(self):
        """Fuzzed seeds stay in range and stay recognisably derived from
        one of the seeds."""
        lower, upper = self.min, self.max

        max_delta = 4
        half_delta = max_delta // 2
        # Seeds sit at zero and just inside both ends of the range.
        seed_values = [0, self.max - half_delta, self.min + half_delta]

        fuzz_only = {
            "seeds": seed_values,
            "p_random": 0,
            "p_fuzz": 1,
            "fuzz_max_delta": max_delta,
        }

        for value in self._generated_ints(fuzz_only):
            self.assertGreaterEqual(value, lower)
            self.assertLessEqual(value, upper)
            self.assertTrue(self.is_fuzzed(seed_values, value, max_delta))


class TestByteRandomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer randomizer tests for the BYTE type."""

    ttype = Thrift.TType.BYTE
    # Width in bits; the inherited min/max properties derive from it.
    n_bits = 8


class TestI16Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer randomizer tests for the I16 type."""

    ttype = Thrift.TType.I16
    # Width in bits; the inherited min/max properties derive from it.
    n_bits = 16


class TestI32Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer randomizer tests for the I32 type."""

    ttype = Thrift.TType.I32
    # Width in bits; the inherited min/max properties derive from it.
    n_bits = 32


class TestI64Randomizer(TestIntRandomizer, unittest.TestCase):
    """Run the shared integer randomizer tests for the I64 type."""

    ttype = Thrift.TType.I64
    # Width in bits; the inherited min/max properties derive from it.
    n_bits = 64


class TestFloatRandomizer(TestRandomizer):
    """Shared tests for the floating-point randomizers.

    The tests build their randomizer from ``self.randomizer_cls.ttype``.
    """

    @property
    def randomizer_cls(self):
        """Return ``randomizer_cls`` as looked up on the instance's class.

        A class attribute of the same name on a subclass takes the place of
        this property for that subclass.
        """
        return type(self).randomizer_cls

    def _generated_floats(self, constraints):
        """Yield ``iterations`` values from a randomizer for
        ``randomizer_cls.ttype`` built with `constraints`."""
        gen = self.get_randomizer(self.randomizer_cls.ttype, None, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testZero(self):
        """``p_zero`` of 1 (and no unreal values) always yields zero."""
        for value in self._generated_floats({"p_zero": 1, "p_unreal": 0}):
            self.assertEqual(value, 0.0)

    def testNonZero(self):
        """``p_zero`` of 0 never yields zero."""
        never_zero = {
            "p_zero": 0,
        }
        for value in self._generated_floats(never_zero):
            self.assertNotEqual(value, 0.0)

    def testUnreal(self):
        """``p_unreal`` of 1 only yields NaN or an infinity."""
        for value in self._generated_floats({"p_unreal": 1}):
            self.assertTrue(math.isnan(value) or math.isinf(value))

    def testReal(self):
        """``p_unreal`` of 0 never yields NaN or an infinity."""
        for value in self._generated_floats({"p_unreal": 0}):
            self.assertFalse(math.isnan(value) or math.isinf(value))

    def testConstant(self):
        """A zero standard deviation pins every value to the mean."""
        expected = 77.2
        pinned = {
            "mean": expected,
            "std_deviation": 0,
            "p_unreal": 0,
            "p_zero": 0,
        }
        for value in self._generated_floats(pinned):
            self.assertEqual(value, expected)

    def testChoices(self):
        """Values are always drawn from ``choices``."""
        allowed = [float("-inf"), 0.0, 13.37]
        for value in self._generated_floats({"choices": allowed}):
            self.assertIn(value, allowed)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seed_values = [float("-inf"), 0.0, 13.37]
        seeded = {"seeds": seed_values, "p_random": 0, "p_fuzz": 0}
        for value in self._generated_floats(seeded):
            self.assertIn(value, seed_values)

    def testIntSeeded(self):
        """Integer seeds are accepted; generated values compare equal to
        one of them."""
        seed_values = [1, 2, 3]
        seeded = {"seeds": seed_values, "p_random": 0, "p_fuzz": 0}
        for value in self._generated_floats(seeded):
            self.assertIn(value, seed_values)


class TestSinglePrecisionRandomizer(TestFloatRandomizer, unittest.TestCase):
    """Run the shared float randomizer tests against
    ``SinglePrecisionFloatRandomizer``."""

    # Class attribute read by the shared tests as ``self.randomizer_cls``.
    randomizer_cls = randomizer.SinglePrecisionFloatRandomizer
    ttype = Thrift.TType.FLOAT


class TestDoublePrecisionRandomizer(TestFloatRandomizer, unittest.TestCase):
    """Run the shared float randomizer tests against
    ``DoublePrecisionFloatRandomizer``."""

    # Class attribute read by the shared tests as ``self.randomizer_cls``.
    randomizer_cls = randomizer.DoublePrecisionFloatRandomizer
    ttype = Thrift.TType.DOUBLE


class TestStringRandomizer(TestRandomizer, unittest.TestCase):
    """Exercise the STRING randomizer."""

    def testInRange(self):
        """Every character of an unconstrained string has a code point
        inside ``StringRandomizer.ascii_range``."""
        ascii_min, ascii_max = randomizer.StringRandomizer.ascii_range

        gen = self.get_randomizer(Thrift.TType.STRING, None, {})
        for _ in range(self.iterations):
            for char in gen.generate():
                self.assertTrue(ascii_min <= ord(char) <= ascii_max)

    def testEmpty(self):
        """A ``mean_length`` of 0 only yields empty strings."""
        constraints = {"mean_length": 0}
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in range(self.iterations):
            # Assert on the length directly rather than from inside a loop
            # over the characters, which only runs for non-empty values.
            self.assertEqual(0, len(gen.generate()))

    def testChoices(self):
        """Values are always drawn from ``choices``."""
        choices = ["foo", "bar"]
        constraints = {"choices": choices}
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in range(self.iterations):
            self.assertIn(gen.generate(), choices)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seeds = ["foo", "bar"]
        constraints = {"seeds": seeds, "p_random": 0, "p_fuzz": 0}
        gen = self.get_randomizer(Thrift.TType.STRING, None, constraints)
        for _ in range(self.iterations):
            self.assertIn(gen.generate(), seeds)


class TestListRandomizer(TestRandomizer, unittest.TestCase):
    """Exercise the LIST randomizer, including seed fuzzing."""

    @classmethod
    def _has_extra_element(cls, shorter, longer):
        """Return True if the list `longer` can be created by inserting
        one element into `shorter`.

        `longer` must be exactly one element longer than `shorter`.
        """
        assert len(shorter) + 1 == len(longer)
        for i, elem in enumerate(shorter):
            if elem != longer[i]:
                # longer[i] has to be the inserted element, so everything
                # from here on in `shorter` (including `elem` itself) must
                # line up with `longer` shifted by one position.
                return all(
                    old == new for old, new in zip(shorter[i:], longer[i + 1 :])
                )
        # `shorter` is a prefix of `longer`: the extra element comes last.
        return True

    @classmethod
    def _is_fuzzed_single_seed(cls, seed, fuzzed):
        """Check whether `fuzzed` could have been generated by fuzzing `seed`

        Requires that the elements of the list are of type i32
        """
        len_delta = len(fuzzed) - len(seed)
        if len_delta == -1:
            # An element was deleted. All other elements should be unaltered
            return cls._has_extra_element(fuzzed, seed)
        if len_delta == 1:
            # An element was inserted. All other elements should be unaltered
            return cls._has_extra_element(seed, fuzzed)
        if len_delta != 0:
            return False

        # Same length: at most one element may have been fuzzed.
        different_elements = [
            (old, new) for old, new in zip(seed, fuzzed) if old != new
        ]
        if not different_elements:
            # Fuzzed element was not changed
            return True
        if len(different_elements) > 1:
            return False
        seed_elem, fuzzed_elem = different_elements[0]
        return TestI32Randomizer.is_fuzzed([seed_elem], fuzzed_elem)

    @classmethod
    def is_fuzzed(cls, seeds, fuzzed):
        """Check whether `fuzzed` could have been generated by
        fuzzing any element of `seeds`. Requires that the elements
        of the list are of type i32"""
        return any(cls._is_fuzzed_single_seed(seed, fuzzed) for seed in seeds)

    def testEmpty(self):
        """A ``mean_length`` of 0 only yields empty lists."""
        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)  # Elements are i32
        constraints = {"mean_length": 0}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in range(self.iterations):
            self.assertEqual(len(gen.generate()), 0)

    def testMaxLength(self):
        """``max_length`` is enforced even when ``mean_length`` exceeds it."""
        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)  # Elements are i32
        constraints = {"mean_length": 100, "max_length": 99}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        # Generate a lot of lists that should never be over 99 long
        for _ in range(self.iterations):
            self.assertLessEqual(len(gen.generate()), 99)

    def testElementConstraints(self):
        """Constraints under ``element`` apply to the list's elements."""
        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.BOOL, None)
        constraints = {"element": {"p_true": 0}}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in range(self.iterations):
            for elem in gen.generate():
                self.assertFalse(elem)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seeds = [[True, False, True], [False, False], []]

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.BOOL, None)
        constraints = {"seeds": seeds, "p_random": 0, "p_fuzz": 0}

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in range(self.iterations):
            self.assertIn(gen.generate(), seeds)

    def testFuzzed(self):
        """Every fuzzed list must be explainable as a fuzz of some seed."""
        seeds = [[], [1], [1, 2], [1, 2, 3], [1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 1]]

        ttype = Thrift.TType.LIST
        spec_args = (Thrift.TType.I32, None)
        constraints = {
            "seeds": seeds,
            "p_random": 0,
            "p_fuzz": 1,
            "element": {"p_random": 0, "p_fuzz": 1},
        }

        gen = self.get_randomizer(ttype, spec_args, constraints)

        for _ in range(self.iterations):
            val = gen.generate()
            self.assertTrue(
                self.is_fuzzed(seeds, val), msg="val %s not generated by fuzzing" % val
            )


class TestSetRandomizer(TestRandomizer, unittest.TestCase):
    """Exercise the SET randomizer."""

    def _generated_sets(self, elem_ttype, constraints):
        """Yield ``iterations`` values from a SET randomizer whose element
        spec is ``(elem_ttype, None)``."""
        gen = self.get_randomizer(Thrift.TType.SET, (elem_ttype, None), constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testEmpty(self):
        """A ``mean_length`` of 0 only yields empty sets."""
        # Elements are i32
        for generated in self._generated_sets(Thrift.TType.I32, {"mean_length": 0}):
            self.assertEqual(len(generated), 0)

    def testElementConstraints(self):
        """Constraints under ``element`` apply to the set's members."""
        all_false = {"element": {"p_true": 0}}
        for generated in self._generated_sets(Thrift.TType.BOOL, all_false):
            for member in generated:
                self.assertFalse(member)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seed_sets = [{1, 2, 3}, set(), {-1, -2, -3}]
        seeded = {"seeds": seed_sets, "p_random": 0, "p_fuzz": 0}
        for generated in self._generated_sets(Thrift.TType.I32, seeded):
            self.assertIn(generated, seed_sets)


class TestMapRandomizer(TestRandomizer, unittest.TestCase):
    """Exercise the MAP randomizer."""

    def _generated_maps(self, spec_args, constraints):
        """Yield ``iterations`` values from a MAP randomizer built with
        `spec_args` and `constraints`."""
        gen = self.get_randomizer(Thrift.TType.MAP, spec_args, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testEmpty(self):
        """A ``mean_length`` of 0 only yields empty maps."""
        i32_to_i16 = (Thrift.TType.I32, None, Thrift.TType.I16, None)
        for generated in self._generated_maps(i32_to_i16, {"mean_length": 0}):
            self.assertEqual(len(generated), 0)

    def testKeyConstraints(self):
        """Constraints under ``key`` apply to the map's keys."""
        bool_to_i16 = (Thrift.TType.BOOL, None, Thrift.TType.I16, None)
        false_keys = {"key": {"p_true": 0}}
        for generated in self._generated_maps(bool_to_i16, false_keys):
            # Iterating the map walks its keys.
            for key in generated:
                self.assertFalse(key)

    def testValConstraints(self):
        """Constraints under ``value`` apply to the map's values."""
        i32_to_color = (Thrift.TType.I32, None, Thrift.TType.I32, ttypes.Color)
        valid_values = {"value": {"p_invalid": 0}}
        for generated in self._generated_maps(i32_to_color, valid_values):
            for value in generated.values():
                self.assertIn(value, ttypes.Color._VALUES_TO_NAMES)

    def testSeeded(self):
        """With random and fuzz generation disabled, only seeds appear."""
        seed_maps = [{1: "foo", 2: "fee", 3: "fwee"}, {}, {0: ""}]
        i32_to_string = (Thrift.TType.I32, None, Thrift.TType.STRING, None)
        seeded = {"seeds": seed_maps, "p_random": 0, "p_fuzz": 0}
        for generated in self._generated_maps(i32_to_string, seeded):
            self.assertIn(generated, seed_maps)


class TestStructRandomizer(TestRandomizer, unittest.TestCase):
    """Exercise the STRUCT randomizer and provide helpers for building one."""

    def get_spec_args(self, ttype):
        """Return the ``(ttype, thrift_spec, is_union)`` spec args tuple
        for the struct class `ttype`."""
        thrift_spec = ttype.thrift_spec
        is_union = ttype.isUnion()
        return (ttype, thrift_spec, is_union)

    def struct_randomizer(self, ttype=None, constraints=None):
        """Build a STRUCT randomizer.

        `ttype` falls back to the ``ttype`` attribute of the test class and
        a falsy `constraints` is replaced by an empty dict.
        """
        struct_type = self.__class__.ttype if ttype is None else ttype
        spec_args = self.get_spec_args(struct_type)
        return self.get_randomizer(Thrift.TType.STRUCT, spec_args, constraints or {})

    def _generated_structs(self, struct_type, constraints):
        """Yield ``iterations`` values from a randomizer for `struct_type`."""
        gen = self.struct_randomizer(struct_type, constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testNoFieldConstraints(self):
        """Per-field ``p_include`` of 0 leaves ``a`` and ``b`` unset."""
        exclude_both = {
            "per_field": {"a": {"p_include": 0.0}, "b": {"p_include": 0.0}},
            "p_include": 1.0,
        }

        for struct in self._generated_structs(ttypes.StructWithOptionals, exclude_both):
            self.assertIsNotNone(struct)
            self.assertIsNone(struct.a)
            self.assertIsNone(struct.b)

    def testOneTrueFieldConstraints(self):
        """``a`` is always included and true while ``b`` is left unset."""
        only_a = {
            "per_field": {"b": {"p_include": 0.0}},
            "p_include": 1.0,
            "a": {"p_true": 1.0},
        }

        for struct in self._generated_structs(ttypes.StructWithOptionals, only_a):
            self.assertIsNotNone(struct)
            self.assertTrue(struct.a)
            self.assertIsNone(struct.b)

    def testStructSeed(self):
        """Type-level seed rules for structs still produce a value."""
        rainbow_seed = ttypes.Rainbow(colors=[ttypes.Color.RED])
        nested_seed = ttypes.NestedStructs(
            rainbow=ttypes.Rainbow(colors=[ttypes.Color.ORANGE])
        )
        type_rules = {
            "|Rainbow": {
                "seeds": [rainbow_seed],
            },
            "|NestedStructs": {
                "seeds": [nested_seed],
            },
        }
        for struct in self._generated_structs(ttypes.NestedStructs, type_rules):
            self.assertIsNotNone(struct)

    def testNestedConstraints(self):
        """Constraints given for field ``c`` apply to the nested struct."""
        nested = {
            "per_field": {"b": {"p_include": 0.0}},
            "p_include": 1.0,
            "a": {"p_true": 1.0},
            "c": {
                "p_include": 1.0,
                "per_field": {"b": {"p_include": 0.0}},
            },
        }

        for struct in self._generated_structs(ttypes.StructWithOptionals, nested):
            self.assertIsNotNone(struct)
            self.assertTrue(struct.a)
            self.assertIsNone(struct.b)
            self.assertIsNotNone(struct.c.a)
            self.assertIsNone(struct.c.b)

    def testEmptyUnion(self):
        """A union without fields is generated as None."""
        for struct in self._generated_structs(ttypes.EmptyUnion, {}):
            # Because the union has no valid fields it's hard to generate
            # a reasonable value, so None is the expected result.
            self.assertIsNone(struct)

    def testStructContainingDefaultUnion(self):
        """``NumberUnionStruct`` is always generated as a non-None value."""
        for struct in self._generated_structs(ttypes.NumberUnionStruct, {}):
            self.assertIsNotNone(struct)

    def testSubRanomizersHaveDefaults(self):
        """A constraint given to the struct must not leak into the field
        randomizers, whereas the global constraint must reach them."""
        # Use a key that is never consumed and is not part of any
        # existing defaults.
        targeted_key = "targeted_constraint"
        gen = self.struct_randomizer(ttypes.StructWithOptionals, {targeted_key: 100})

        # This reaches into a private attribute, which is ugly, but it shows
        # which constraints get propagated and which do not.
        for rule in gen._field_rules.values():
            field_randomizer = rule["randomizer"]

            # The targeted key will only apply to the first randomizer
            self.assertNotIn(targeted_key, field_randomizer.constraints.keys())

            # The global one is propagated to everything.
            self.assertIn("global_constraint", field_randomizer.constraints)


class TestUnionRandomizer(TestStructRandomizer, unittest.TestCase):
    """Exercise the STRUCT randomizer on the ``IntUnion`` union type."""

    ttype = ttypes.IntUnion

    def _generated_unions(self, constraints):
        """Yield ``iterations`` values from an ``IntUnion`` randomizer
        built with `constraints`."""
        gen = self.struct_randomizer(constraints=constraints)
        for _ in range(self.iterations):
            yield gen.generate()

    def testAlwaysInclude(self):
        """``p_include`` of 1 always sets a field of the union."""
        for union in self._generated_unions({"p_include": 1}):
            # Check that field is nonzero, indicating a field is set
            self.assertNotEqual(union.field, 0)

    def testNeverInclude(self):
        """``p_include`` of 0 yields None instead of a union."""
        for union in self._generated_unions({"p_include": 0}):
            # No field can be included, so no union can be built at all.
            self.assertIsNone(
                union,
                (
                    "Because there's no way to add fields of a "
                    "union there should be no way to create the union."
                ),
            )

    def testSeededFuzz(self):
        """Seeded generation without random values always yields a union."""
        seed_unions = [ttypes.IntUnion(a=20), ttypes.IntUnion(b=40)]
        seeded = {"seeds": seed_unions, "p_random": 0}
        for union in self._generated_unions(seeded):
            self.assertIsNotNone(
                union,
                (
                    "The union should always be created. "
                    "We don't know the expected values, "
                    "just that they exist"
                ),
            )

    def testSeeded(self):
        """Dict and struct seeds alike are reproduced exactly when random
        and fuzz generation are disabled."""
        seed_unions = [
            {"a": 2},
            {"b": 4},
            ttypes.IntUnion(a=2),
            ttypes.IntUnion(b=4),
        ]
        seeded = {"seeds": seed_unions, "p_random": 0, "p_fuzz": 0}

        # (field id, value that field must hold) for each seed variant.
        expected_by_field = ((1, 2), (2, 4))

        def is_seed(val):
            for field_id, expected_value in expected_by_field:
                if val.field == field_id:
                    return val.value == expected_value
            return False

        for union in self._generated_unions(seeded):
            self.assertTrue(
                is_seed(union), msg="Not a seed: %s (%s)" % (union, union.__dict__)
            )


class TestListStructRandomizer(TestStructRandomizer, unittest.TestCase):
    """
    Verify that this struct type is generated correctly

    struct ListStruct {
        1: list<bool> a;
        2: list<i16> b;
        3: list<double> c;
        4: list<string> d;
        5: list<list<i32>> e;
        6: list<map<i32, i32>> f;
        7: list<set<string>> g;
    }
    """

    ttype = ttypes.ListStruct

    @staticmethod
    def _seed():
        """Return a fresh seed dict that populates every ListStruct field."""
        return {
            "a": [True, False, False],
            "b": [1, 2, 3],
            "c": [1.2, 2.3],
            "d": ["foo", "bar"],
            "e": [[]],
            "f": [{1: 2}, {3: 4}],
            "g": [{"foo", "bar"}],
        }

    def _assert_container_of(self, container, container_type, elem_type):
        """Assert `container` is a `container_type` whose elements are all
        instances of `elem_type`."""
        self.assertIsInstance(container, container_type)
        for elem in container:
            self.assertIsInstance(elem, elem_type)

    def testGeneration(self):
        """Every field that is set holds values of the declared types."""
        struct = self.struct_randomizer().generate()

        if struct.a is not None:
            self._assert_container_of(struct.a, list, bool)

        if struct.b is not None:
            self._assert_container_of(struct.b, list, int)

        if struct.c is not None:
            self._assert_container_of(struct.c, list, float)

        if struct.d is not None:
            self._assert_container_of(struct.d, list, str)

        if struct.e is not None:
            self.assertIsInstance(struct.e, list)
            for inner in struct.e:
                self._assert_container_of(inner, list, int)

        if struct.f is not None:
            self.assertIsInstance(struct.f, list)
            for mapping in struct.f:
                self.assertIsInstance(mapping, dict)
                for key, value in mapping.items():
                    self.assertIsInstance(key, int)
                    self.assertIsInstance(value, int)

        if struct.g is not None:
            self.assertIsInstance(struct.g, list)
            for members in struct.g:
                self._assert_container_of(members, set, str)

    def testFieldConstraints(self):
        """Per-field constraints reach both the lists and their elements."""
        constraints = {
            "p_include": 1.0,
            "a": {"element": {"p_true": 1.0}},
            "b": {"mean_length": 0.0},
            "d": {"element": {"mean_length": 0.0}},
            "e": {"element": {"mean_length": 0.0}},
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(self.iterations):
            struct = gen.generate()
            self.assertIsNotNone(struct)
            self.assertIsNotNone(struct.a)
            self.assertIsNotNone(struct.b)
            self.assertIsNotNone(struct.c)
            self.assertIsNotNone(struct.d)
            self.assertIsNotNone(struct.e)
            self.assertIsNotNone(struct.f)
            self.assertIsNotNone(struct.g)

            for elem in struct.a:
                self.assertTrue(elem)

            self.assertEqual(len(struct.b), 0)

            for elem in struct.d:
                self.assertEqual(len(elem), 0)

            for elem in struct.e:
                self.assertEqual(len(elem), 0)

    def testSeeded(self):
        """A dict seed is reproduced field by field when random and fuzz
        generation are disabled."""
        seed = self._seed()
        constraints = {"seeds": [seed], "p_random": 0, "p_fuzz": 0}

        struct = self.struct_randomizer(constraints=constraints).generate()
        for key, expected in seed.items():
            self.assertEqual(
                expected,
                getattr(struct, key, None),
                msg="%s, %s" % (struct, dir(struct)),
            )

    def testFuzz(self):
        """Fuzzing the seed changes at most one field."""
        seed = self._seed()
        constraints = {"seeds": [seed], "p_random": 0, "p_fuzz": 1}

        gen = self.struct_randomizer(constraints=constraints)
        for _ in range(self.iterations):
            struct = gen.generate()
            n_different = sum(
                1
                for key, expected in seed.items()
                if expected != getattr(struct, key, None)
            )
            self.assertLessEqual(n_different, 1)

    def testBoolTypeConstraints(self):
        """A ``|bool`` type rule applies to the elements of ``a``."""
        constraints = {"p_include": 1.0, "|bool": {"p_true": 1.0}}

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(self.iterations):
            struct = gen.generate()
            self.assertIsNotNone(struct)
            self.assertIsNotNone(struct.a)
            for elem in struct.a:
                self.assertTrue(elem)

    def testDoubleTypeConstraints(self):
        """A ``|double`` type rule applies to the elements of ``c``."""
        constant = 11.1

        constraints = {
            "p_include": 1.0,
            "|double": {
                "mean": constant,
                "std_deviation": 0,
                "p_unreal": 0,
                "p_zero": 0,
            },
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(self.iterations):
            struct = gen.generate()
            self.assertIsNotNone(struct)
            self.assertIsNotNone(struct.c)
            for elem in struct.c:
                self.assertEqual(elem, constant)

    def testListTypeConstraints(self):
        """
        We have an |i32 rule and a |list<i32> rule,
        The |list<i32> rule is more specific and overrides the |i32 rule
        """
        int_choices = [1, 2, 3]
        list_int_choices = [4, 5, 6]

        constraints = {
            "p_include": 1.0,
            "|list<i32>": {"element": {"choices": list_int_choices}},
            "|i32": {"choices": int_choices},
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(self.iterations):
            struct = gen.generate()
            self.assertIsNotNone(struct)
            self.assertIsNotNone(struct.e)
            # Iterate through each i32 element in list<list<i32>>
            for elem in itertools.chain.from_iterable(struct.e):
                self.assertIn(elem, list_int_choices)

    def testMapTypeConstraints(self):
        """
        ListStruct.f is type list<map<i32, i32>>
        We use two rules:
        |i32
        |map<i32, i32>.key

        The |i32 rule will be used on map values, and the .key rule
        will be used on map keys since it is more specific.
        """
        int_choices = [1, 2, 3]
        key_int_choices = [4, 5, 6]

        constraints = {
            "p_include": 1.0,
            "|map<i32, i32>": {"key": {"choices": key_int_choices}},
            "|i32": {"choices": int_choices},
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(self.iterations):
            struct = gen.generate()
            self.assertIsNotNone(struct)
            self.assertIsNotNone(struct.f)
            # Iterating a map walks its keys.
            for key in itertools.chain.from_iterable(struct.f):
                self.assertIn(key, key_int_choices)
            # Use a distinct loop variable so the generated struct is not
            # rebound while its maps are being walked.
            for value in itertools.chain.from_iterable(
                mapping.values() for mapping in struct.f
            ):
                self.assertIn(value, int_choices)


class TestNestedStruct(TestStructRandomizer, unittest.TestCase):
    """Fuzzing test for ttypes.NestedStructs, whose fields are themselves structs."""

    ttype = ttypes.NestedStructs

    def testFuzz(self):
        """Check that at most one subfield is randomized

        A single seed is supplied with p_random=0 and p_fuzz=1. Each
        generated value is compared subfield-by-subfield against that seed,
        and the test fails when more than one subfield differs.
        """
        cls = self.__class__

        seed = {
            "ls": {
                "a": [True, True],
                "b": [1, 2, 3],
                "c": [1.2, 3.4],
                "d": ["foo"],
                "e": [[1], [2, 3]],
                "f": [{1: -1}, {-1: 1}],
                "g": [{"a", "b"}, set(), set()],
            },
            "rainbow": {
                "colors": ["YELLOW", "YELLOW", "GRAY"],
                "brightness": float("inf"),
            },
            "ints": {"a": 1000},
        }

        constraints = {
            "seeds": [seed],
            "p_random": 0,
            "p_fuzz": 1,
            "ls": {"p_random": 0},
            "rainbow": {"p_random": 0},
            "ints": {"p_random": 0},
        }

        gen = self.struct_randomizer(constraints=constraints)

        for _ in range(cls.iterations):
            val = gen.generate()

            # To check the generated value against the seed,
            # convert enum numbers back into enum names
            val.rainbow.colors = [
                ttypes.Color._VALUES_TO_NAMES.get(c, c) for c in val.rainbow.colors
            ]

            # One (path, expected, actual) entry per subfield that differs
            differents = []

            # Compare actual (generated) structs to the seed structs.
            # A tuple (not a set) keeps the failure message order stable.
            for struct_name in ("ls", "rainbow"):
                actual_struct = getattr(val, struct_name)
                for key, expected_field in seed[struct_name].items():
                    actual_field = getattr(actual_struct, key, None)
                    if expected_field != actual_field:
                        differents.append(
                            ("%s.%s" % (struct_name, key), expected_field, actual_field)
                        )

            # Fuzzed union should use the same field as the seed union (a=1)
            self.assertEqual(val.ints.field, 1)

            expected = seed["ints"]["a"]
            actual = val.ints.value
            if actual != expected:
                differents.append(("ints.a", expected, actual))

            # List every differing subfield in the failure message
            message = ", ".join("%s: %s != %s" % diff for diff in differents)
            self.assertLessEqual(len(differents), 1, msg=message)


class TestStructRecursion(TestStructRandomizer, unittest.TestCase):
    """Check that max_recursion_depth bounds the depth of generated ttypes.BTree values."""

    ttype = ttypes.BTree

    def max_depth(self, tree):
        """Find the maximum depth of a ttypes.BTree struct

        A ttypes.BTreeBranch is unwrapped to its `child` before measuring.
        Returns 0 when `tree` is None; otherwise 1 plus the largest depth
        among `tree.children` (so a tree without children has depth 1).
        """
        if tree is None:
            return 0

        if isinstance(tree, ttypes.BTreeBranch):
            tree = tree.child

        child_depths = [self.max_depth(child) for child in tree.children]
        # max() rejects an empty sequence, so a childless tree contributes 0
        return 1 + (max(child_depths) if child_depths else 0)

    def _assert_depth_at_most(self, limit):
        """Generate `iterations` values with max_recursion_depth=limit and
        assert that none of them is deeper than `limit`."""
        gen = self.struct_randomizer(constraints={"max_recursion_depth": limit})

        for _ in range(self.__class__.iterations):
            val = gen.generate()
            self.assertLessEqual(self.max_depth(val), limit)

    def testDepthZero(self):
        """With max_recursion_depth 0 the generated value must have depth 0."""
        constraints = {"max_recursion_depth": 0}

        gen = self.struct_randomizer(constraints=constraints)

        # Verify the constraint reached the randomizer. An assertion makes a
        # mismatch show up as a test failure instead of an unexpected error.
        depth = gen.constraints["max_recursion_depth"]
        self.assertEqual(depth, 0, msg="Invalid recursion depth %s" % (depth,))

        val = gen.generate()

        # max_depth only returns 0 for None
        self.assertEqual(self.max_depth(val), 0)

    def testDepthOne(self):
        """Generated trees are never deeper than one level."""
        self._assert_depth_at_most(1)

    def testDepthTwo(self):
        """Generated trees are never deeper than two levels."""
        self._assert_depth_at_most(2)


# Run the tests in this module through unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()
